import numpy as np
import quantecon as qe
from scipy.interpolate import interp1d
from scipy.optimize import minimize
from scipy.optimize import broyden1
from scipy.optimize import brentq
import matplotlib.pyplot as plt
from numpy.linalg import inv
from mpl_toolkits import mplot3d
import support_functions as sp
import pandas as pd
from numpy.polynomial.legendre import leggauss
import pickle
import time
import os
from scipy.optimize import minimize
from scipy.optimize import root
import seaborn as sns


def plot_results(res, colors, rescale=1, savename=None, tmax=15, dlogw=0.04):
    """Plot the paths of several results side by side on a 2-column panel grid.

    Each entry of ``res`` must unpack into exactly ten rows, in this order:
    dlogP, dlogY, dlogmu_bar, dlogL, dlogP_Y, dlogmu_small, dlogmu_large,
    frac_change, frac_change_small, frac_change_large.  Two extra series are
    derived from them (dlogA = dlogY - dlogL, dlogM = dlogP_Y + dlogY), and
    ten panels are drawn (one per entry of ``titles``), with one line per
    result.  dlogmu_small, dlogmu_large and frac_change are not plotted.

    Args:
        res: sequence of per-result bundles, e.g. arrays loaded with
            ``np.load``; each bundle must have exactly ten rows.
        colors: matplotlib colours, indexed in parallel with ``res``
            (``colors[i]`` is used for ``res[i]``).
        rescale: factor applied to every plotted series (e.g. 100).
        savename: if not None, the figure is saved to this path
            (dpi=300, tight bounding box) and then closed.
        tmax: only the first ``tmax`` entries of each row are plotted.
        dlogw: constant value drawn in the 'dlogw' panel (previously
            hard-coded as 0.04, which remains the default).

    Returns:
        The matplotlib Figure that was drawn.

    Raises:
        ValueError: if a bundle does not unpack into exactly ten rows.
        IndexError: if ``colors`` is shorter than ``res``.
    """
    titles = ['dlogw', 'dlogM', 'dlogP', 'dlogP_Y', 'dlogY', 'dlogL', 'dlogA',
              'dlogmu_bar', '% price change (small)', '% price change (large)']
    nrows = (len(titles) + 1) // 2  # two panels per row, rounded up
    fig, axs = plt.subplots(nrows, 2, figsize=(8, 1.5 * nrows))
    axs = axs.flatten()

    for res_ind, paths in enumerate(res):
        color = colors[res_ind]
        # np.asarray so that list input is scaled numerically by `rescale`
        # below instead of being repeated by list multiplication.
        (dlogP_path, dlogY_path, dlogmu_bar_path, dlogL_path, dlogP_Y_path,
         _dlogmu_small, _dlogmu_large, _frac_change,
         frac_change_small, frac_change_large) = [np.asarray(p) for p in paths]

        dlogA_path = dlogY_path - dlogL_path    # log change of Y / L
        dlogM_path = dlogP_Y_path + dlogY_path  # log change of P_Y * Y
        # Constant path, sized to match this result's own rows.
        dlogw_path = np.full(len(dlogP_path), dlogw)

        series = [dlogw_path, dlogM_path, dlogP_path, dlogP_Y_path, dlogY_path,
                  dlogL_path, dlogA_path, dlogmu_bar_path,
                  frac_change_small, frac_change_large]
        for ax, values in zip(axs, series):
            ax.plot(values[:tmax] * rescale, marker='o', color=color,
                    fillstyle='none')

    for ax, title in zip(axs, titles):
        ax.set_title(title)

    # Leaves the top 5% of the figure free, as in the original layout.
    fig.tight_layout(rect=[0, 0, 1, 0.95])
    if savename is not None:
        # Previously the save was commented out, so passing `savename`
        # discarded the figure without writing anything.
        fig.savefig(savename, dpi=300, bbox_inches='tight')
        plt.close(fig)
    return fig


res1 = np.load('menu_cost_wage_paths/CES_matchKimball_menu2_invfrisch5_shocksize4/final_res.npy')
res2 = np.load('menu_cost_wage_paths/Kimball_symmask_menu2_invfrisch5_shocksize4/final_res.npy')

plot_results([res1, res2], ['green', 'blue'], rescale=100, tmax=36)